# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from enum import Enum
from typing import Any
from typing import ClassVar
from typing import Literal
from typing import Self
from typing import TypeAlias

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator

from nat.plugins.data_flywheel.observability.schema.schema_registry import register_schema

from .contract_version import ContractVersion

# Module-level logger; DFWESRecord's data-consistency validator emits its warnings through it.
logger = logging.getLogger(__name__)


class FinishReason(str, Enum):
    """Finish reason for chat completion responses.

    Values mirror the ``finish_reason`` strings of the OpenAI Chat Completions API.
    A choice that stopped for any documented reason therefore validates instead of
    rejecting the whole record.
    """

    STOP = "stop"
    LENGTH = "length"
    TOOL_CALLS = "tool_calls"
    # Output was withheld by a content filter.
    CONTENT_FILTER = "content_filter"
    # Legacy (deprecated) function-calling API; superseded by ``tool_calls``.
    FUNCTION_CALL = "function_call"


class Function(BaseModel):
    """Function call made by the model: the function's name and its arguments.

    ``arguments`` is stored as a dict. OpenAI-style payloads encode the arguments as a
    JSON string, so string input is decoded with ``json.loads`` before the regular
    ``dict`` validation runs. Dict input is passed through unchanged.
    """

    name: str = Field(..., description="The name of the function to call.")
    arguments: dict = Field(
        ...,
        description="The arguments to call the function with, as generated by the model in JSON format.",
    )

    @field_validator("arguments", mode="before")
    @classmethod
    def decode_json_arguments(cls, v: Any) -> Any:
        """Decode JSON-encoded ``arguments`` strings into Python objects.

        Args:
            v: Raw ``arguments`` value as supplied by the caller.

        Returns:
            The decoded JSON value when ``v`` is a string, otherwise ``v`` itself. A
            decoded value that is not a JSON object still fails the ``dict`` validation
            that follows.

        Raises:
            ValueError: If ``v`` is a string that is not valid JSON.
        """
        if not isinstance(v, str):
            # Non-strings (normally dicts) keep the original validation path.
            return v
        try:
            return json.loads(v)
        except json.JSONDecodeError as err:
            raise ValueError(f"Function arguments are not valid JSON: {err}") from err


class ToolCall(BaseModel):
    """A single tool invocation emitted by the model.

    Only function tools exist, so the tool type is pinned to ``"function"``. The
    attribute is named ``type_`` and is read from the ``"type"`` key through its alias.
    """

    function: Function = Field(description="The function that the model called.")
    type_: Literal["function"] = Field(
        "function",
        alias="type",
        description="The type of the tool. Currently, only `function` is supported.",
    )


class AssistantMessage(BaseModel):
    """Assistant turn; one of the members of the ``Message`` union.

    ``role`` accepts both ``"assistant"`` and ``"ai"``. Text content and tool calls are
    both optional.
    """

    role: Literal["assistant", "ai"] = Field(description="The role of the messages author, in this case `assistant`.")
    content: str | None = Field(None, description="The contents of the assistant message.")
    tool_calls: list[ToolCall] | None = Field(
        None,
        description="The tool calls generated by the model, such as function calls.",
    )


class SystemMessage(BaseModel):
    """System-prompt turn; one of the members of the ``Message`` union."""

    content: str = Field(description="The contents of the system message.")
    role: Literal["system"] = Field(description="The role of the messages author, in this case `system`.")


class UserMessage(BaseModel):
    """End-user turn; ``role`` may be either ``"user"`` or ``"human"``."""

    content: str = Field(description="The contents of the user message.")
    role: Literal["user", "human"] = Field(description="The role of the messages author, in this case `user`.")


class ToolMessage(BaseModel):
    """Tool-result turn in a conversation; one of the members of the ``Message`` union.

    ``tool_call_id`` links the result back to the tool call it answers.
    """

    content: str = Field(description="The contents of the tool message.")
    role: Literal["tool"] = Field(description="The role of the messages author, in this case `tool`.")
    tool_call_id: str = Field(description="Tool call that this message is responding to.")


class FunctionMessage(BaseModel):
    """Function/chain turn; one of the members of the ``Message`` union.

    ``role`` may be ``"function"`` or ``"chain"``, and ``content`` is optional.
    """

    content: str | None = Field(None, description="The contents of the function message.")
    role: Literal["function", "chain"] = Field(description="The role of the messages author, in this case `function`.")


# Any message that may appear in a request's ``messages`` list. The member order is
# kept unchanged on purpose, since pydantic may consult it when validating the union.
Message: TypeAlias = (
    SystemMessage
    | UserMessage
    | AssistantMessage
    | ToolMessage
    | FunctionMessage
)


class FunctionParameters(BaseModel):
    """JSON-Schema-style parameter specification of a tool function.

    Nested under ``FunctionDetails.parameters`` in the tool definitions of a request.
    """

    # Maximum number of entries allowed in ``properties`` (workaround for a NIM bug).
    # Declared as ``ClassVar`` so pydantic does not treat it as a model field.
    MAX_PROPERTIES: ClassVar[int] = 8

    properties: dict = Field(..., description="The properties of the function parameters.")
    required: list[str] = Field(..., description="The required properties of the function parameters.")
    type_: Literal["object"] = Field(default="object", description="The type of the function parameters.", alias="type")

    @field_validator("properties")
    @classmethod
    def validate_property_limit(cls, v: dict) -> dict:
        """Enforce the ``MAX_PROPERTIES`` limit for tool parameters (WAR for NIM bug).

        Args:
            v: The validated ``properties`` mapping.

        Returns:
            ``v`` unchanged when it is within the limit.

        Raises:
            ValueError: If ``v`` has more than ``MAX_PROPERTIES`` entries.
        """
        limit = cls.MAX_PROPERTIES
        if len(v) > limit:
            # The limit is interpolated so the message can never disagree with the check.
            raise ValueError(f"Tool properties cannot exceed {limit} properties. Found {len(v)} properties.")
        return v


class FunctionDetails(BaseModel):
    """Definition of a callable function as offered to the model in a request's tools."""

    name: str = Field(description="The name of the function.")
    description: str = Field(description="The description of the function.")
    parameters: FunctionParameters = Field(description="The parameters of the function.")


class RequestTool(BaseModel):
    """Tool entry listed in a request; only ``"function"`` tools are accepted."""

    type: Literal["function"] = Field(description="The type of the tool.")
    function: FunctionDetails = Field(description="The function details.")


class ESRequest(BaseModel):
    """OpenAI ChatCompletion-style request as stored in a Data Flywheel Elasticsearch record.

    Keys not modeled here are retained because the model is configured with
    ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")  # Keep request parameters not modeled below

    tools: list[RequestTool] | None = Field(default=None, description="The tools available to the model.")
    messages: list[Message] | None = Field(default=None, description="The conversation messages.")
    model: str = Field(default="", description="The model to use.")

    # Enhanced fields for better tracking
    temperature: float | None = Field(None, description="Sampling temperature", ge=0.0, le=2.0)

    max_tokens: int | None = Field(None, description="Maximum tokens in response", ge=1)


class ResponseMessage(BaseModel):
    """Message carried by a single response choice.

    Every field is optional; when ``role`` is present it must be ``"assistant"``.
    """

    role: Literal["assistant"] | None = Field(
        None,
        description="The role of the messages author, in this case `assistant`.",
    )
    content: str | None = Field(None, description="The contents of the assistant message.")
    tool_calls: list[ToolCall] | None = Field(
        None,
        description="The tool calls generated by the model, such as function calls.",
    )


class ResponseChoice(BaseModel):
    """One entry of a response's ``choices`` list.

    ``message`` is required; ``finish_reason`` and the non-negative ``index`` are optional.
    """

    message: ResponseMessage = Field(description="A chat completion message generated by the model.")
    finish_reason: FinishReason | None = Field(default=None, description="Reason for completion finish")
    index: int | None = Field(default=None, ge=0, description="Choice index")


class Response(BaseModel):
    """OpenAI ChatCompletion-style response as stored in a Data Flywheel Elasticsearch record.

    Every modeled field is optional. Keys not modeled here are retained because the
    model is configured with ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")  # Keep response keys not modeled below

    choices: list[ResponseChoice] | None = Field(default=None, description="The choices.")

    # Optional metadata captured for tracking purposes
    id: str | None = Field(default=None, description="Response ID")
    object: str | None = Field(default=None, description="Object type")
    created: int | None = Field(default=None, description="Creation timestamp")
    model: str | None = Field(default=None, description="Model used for response")
    usage: dict[str, Any] | None = Field(default=None, description="Token usage information")


@register_schema(name="elasticsearch", version="1.0")
@register_schema(name="elasticsearch", version="1.1")
class DFWESRecord(BaseModel):
    """Data Flywheel Elasticsearch record.

    Pairs an OpenAI ChatCompletion-style request with its response, plus the client and
    workload identifiers and a timestamp. It is registered with the schema registry
    under the name ``"elasticsearch"`` for versions ``"1.0"`` and ``"1.1"``. Unknown
    top-level keys are rejected (``extra="forbid"``), and assignments are validated.
    """

    model_config = ConfigDict(extra="forbid", validate_assignment=True)

    # Contract versioning
    contract_version: ContractVersion = Field(default=ContractVersion.V1_0,
                                              description="Contract version for compatibility tracking")

    # Core fields (backward compatible)
    request: ESRequest = Field(..., description="The OpenAI ChatCompletion request.")
    response: Response = Field(..., description="The OpenAI ChatCompletion response.")
    client_id: str = Field(..., description="Identifier of the application or deployment that generated traffic.")
    workload_id: str = Field(..., description="Stable identifier for the logical task / route / agent node.")
    timestamp: int = Field(..., description="The timestamp of the payload in seconds since epoch.")

    # Enhanced tracking fields
    error_details: str | None = Field(None, description="Error details if processing failed", max_length=5000)

    @model_validator(mode="after")
    def validate_data_consistency(self) -> Self:
        """Warn when the request offers tools but no response choice contains a tool call.

        This is a soft check: it only logs a warning and never rejects the record.

        Returns:
            The validated record, unchanged.
        """
        # Only check tool-calling consistency when the request actually offered tools.
        if self.request.tools:
            # ``choices`` defaults to ``None``; treat that as "no choices" rather than
            # iterating over ``None``, which raised ``TypeError`` during validation.
            response_choices = self.response.choices or []
            has_tool_calls = any(choice.message and choice.message.tool_calls for choice in response_choices)
            if not has_tool_calls:
                logger.warning("Request has tools but response has no tool calls")

        return self
